import json
import umap
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn import metrics
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
import statsmodels.api as sm
from gensim.models.doc2vec import Doc2Vec
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import warnings

# Silence every warning for the whole run.
# NOTE(review): this blanket filter also hides pandas chained-assignment and
# numpy empty-mean warnings raised further down; consider narrowing it.
warnings.filterwarnings("ignore")

#---------Loading User Vectors----------#
metadata = pd.read_csv('data/user_metadata.csv')
# Maps account handles (e.g. 'justintrudeau') to the `hash_` ids that are used
# as Doc2Vec document tags. JSON text is UTF-8 by specification, so decode it
# explicitly rather than with the platform default encoding.
with open('data/user_map.json', encoding='utf-8') as fin:
    hasher = json.load(fin)

model = Doc2Vec.load('data/USER_EMBEDDINGS')
doctags = list(model.docvecs.doctags)

# Keep users that have an embedding and are flagged as English speakers.
metadata = metadata[metadata.hash_.isin(doctags)]
metadata = metadata[metadata.english_speaker==1]
# Subset of users having posted once a day on average.
# .copy() makes this an independent frame, so the column assignments below
# (and in later sections) modify it directly instead of a possible view of
# the unfiltered CSV frame.
metadata = metadata[(metadata.freq>=50.0)].copy()

n = metadata.shape[0]
# Row position (0..n-1) of each retained user.
metadata['idx'] = list(range(n))
users = metadata.hash_.tolist()
freqs = metadata.freq.tolist()
# hash_ -> row position; this is also the row order of the user matrix `y`
# and of the projection `z`.
userlookup = {s: i for i, s in enumerate(users)}

# Create user matrix: row i is the Doc2Vec vector of users[i]. The width comes
# from the model's own vectors instead of a hard-coded 200. float64 is kept on
# purpose (the model vectors are presumably float32), so the PCA/UMAP numerics
# -- and the hand-placed annotation coordinates that depend on the resulting
# layout -- are unaffected.
y = np.array([model.docvecs[u] for u in users], dtype=np.float64)

# Apply pca
pca = PCA(n_components=50, random_state=42)
x = pca.fit_transform(y)

# Compute UMAP (2-D projection of the 50 PCA components; fixed seed).
umap_model = umap.UMAP(
    n_neighbors=200,
    min_dist=0.0,
    n_components=2,
    random_state=21)
z = umap_model.fit_transform(x)

# Reorienting x-axis for visual convenience
z[:,0] = z[:,0]*(-1)
# Centering axes (optional)
z[:,0] = z[:,0] - z[:,0].mean()
z[:,1] = z[:,1] - z[:,1].mean()

# Label coordinates for visualizations: each entry of the JSON holds at least
# a 'users' list of hash_ ids (the Table 3 section also reads a 'coords'
# rectangle from some entries).
with open('data/umap_coordinates.json') as fin:
    coords = json.load(fin)

# Centroid of every group in the projection, computed over the members that
# survived the filtering above. A group with no surviving members yields
# (nan, nan), since np.mean of an empty list is nan.
location_coords = {}
for group_name, group in coords.items():
    rows = [userlookup[member] for member in group['users'] if member in userlookup]
    location_coords[group_name] = (
        np.mean([z[r, 0] for r in rows]),
        np.mean([z[r, 1] for r in rows]),
    )

#=======================================================================#
#
# Figure 2: Mapping of Twitter Users during the 2019 Canadian Election
#
#=======================================================================#

plt.rc('axes', titlesize=20)
plt.rc('axes', labelsize=20)
# Proxy artists for the legend (the legend() call supplies the shown labels).
red_patch = Line2D([0], [0], marker='o', markersize=20, color=(1, 0, 0, 0.8),
                   label='Bots', linestyle='None')
blue_patch = Line2D([0], [0], marker='o', markersize=20, color=(0, 0, 1, 0.3),
                    label='Humans', linestyle='None')

bots = (metadata.bot == 'bot')
plt.figure(figsize=(21, 15))


def _annotate_fig2(text, xy, offset):
    """Draw `text` in a translucent rounded box placed `offset` points away
    from data point `xy`, with an arrow pointing at `xy`."""
    plt.annotate(text, xy=xy, xytext=offset, fontsize=14,
                 textcoords='offset points', ha='right', va='bottom',
                 bbox=dict(boxstyle='round,pad=0.5', fc='black', alpha=0.1),
                 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))


# Adding labels and arrows for interpretation.
# Individual accounts: (handle in `hasher`, label text, text offset in points).
account_labels = [
    ('justintrudeau', 'Justin Trudeau', (-60, 20)),
    ('thejagmeetsingh', 'Jagmeet Singh (NDP)', (-60, 20)),
    ('elizabethmay', 'Elizabeth May (Greens)', (-60, 20)),
    ('maximebernier', 'Maxime Bernier (PPC)', (-60, 20)),
    ('cbcnews', 'CBC News', (-20, -50)),
    ('andrewscheer', 'Andrew Scheer', (40, -90)),
]
for handle, text, offset in account_labels:
    row = userlookup[hasher[handle]]
    _annotate_fig2(text, (z[row, 0], z[row, 1]), offset)

# Groups, labelled at their centroid from `location_coords`.
group_labels = [
    ('Foreign (Refugees)', 'Foreign Campaign (Refugees)', (-60, 0)),
    ('Junk News', '"Junk News"', (-20, -50)),
    ('Liberal', 'Liberals (Pro-Trudeau)', (160, 20)),
    ('Conservative', 'Conservatives (Anti-Trudeau)', (140, 90)),
]
for group_name, text, offset in group_labels:
    _annotate_fig2(text, location_coords[group_name], offset)

# NOTE(review): this arrow target and box are hard-coded in projected
# coordinates; they only frame the intended accounts for this exact
# PCA/UMAP run.
_annotate_fig2('Suspected Foreign Accounts', (-1.25, -3.25), (-40, -5))
plt.gca().add_patch(mpatches.Rectangle((-1.25, -4.5), 3.0, 2.75, linewidth=1,
                                       edgecolor='r', facecolor='none'))

# Marker area is freq / 50.
plt.scatter(z[bots, 0], z[bots, 1],
            edgecolor='red', facecolor=(1, 0, 0, 0.8), linewidth=0.1,
            s=(metadata.freq[bots] / 50).tolist())
plt.scatter(z[~bots, 0], z[~bots, 1],
            edgecolor='blue', facecolor=(0, 0, 1, 0.1), linewidth=0.1,
            s=(metadata.freq[~bots] / 50).tolist())

plt.xlabel("UMAP Dimension 1")
plt.ylabel("UMAP Dimension 2")
plt.legend([red_patch, blue_patch], ['Suspected Bots', 'Humans'],
           fontsize=20, loc='lower left')
plt.savefig('figures/figure2.jpeg', dpi=200, bbox_inches='tight')

#---------Figure 2b----------#
# Geotags: users whose `origin` is 'canadian' are drawn as large purple
# markers; everyone else (including rows with no origin, since NaN does not
# equal 'canadian') as small blue dots.
filters = (metadata.origin == 'canadian')

plt.figure(figsize=(14, 10))
# Proxy artist for the single legend entry. legend() pairs handles with labels
# positionally, so exactly one handle is passed for the one label.
red_patch = Line2D([0], [0], marker='o', markersize=20,
                   color=(102/255, 0, 51/255, 0.9), label='', linestyle='None')

plt.scatter(z[filters, 0],
            z[filters, 1],
            edgecolor=(102/255, 0, 51/255), facecolor=(102/255, 0, 51/255, 0.9), linewidth=1,
            s=30, marker="o")
plt.scatter(z[~filters, 0],
            z[~filters, 1],
            edgecolor='blue', facecolor=(0, 0, 1, 0.1), linewidth=0.1,
            s=1)

# NOTE(review): same hard-coded region as Figure 2; only meaningful for this
# exact PCA/UMAP run.
plt.annotate('Suspected Foreign Accounts', xy=(-1.25, -3.25), xytext=(-40, -5), fontsize=14,
             textcoords='offset points', ha='right', va='bottom',
             bbox=dict(boxstyle='round,pad=0.5', fc='black', alpha=0.1),
             arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

plt.gca().add_patch(mpatches.Rectangle((-1.25, -4.5), 3.0, 2.75, linewidth=1,
                                       edgecolor='r', facecolor='none'))

plt.xlabel("UMAP Dimension 1")
plt.ylabel("UMAP Dimension 2")
plt.legend([red_patch], ['Users geotagged within Canada'], fontsize=16, loc='upper right')
plt.savefig('figures/figure2b.jpeg', dpi=100, bbox_inches='tight')

#---------Figure 2c----------#
# Climate + far-right shares: bots with a positive `farright` value (purple)
# and bots with a positive `climate` value (green), marker area = 10 x that
# value; all remaining users are drawn as small blue dots. A bot matching both
# masks is plotted in both layers.
is_bot = (metadata.bot == 'bot')
filters = is_bot & (metadata['farright'] > 0)
filters2 = is_bot & (metadata['climate'] > 0)
others = ~(filters | filters2)
plum = (102/255, 0, 51/255)

plt.figure(figsize=(14, 10))
red_patch = Line2D([0], [0], marker='o', markersize=20, color=plum + (0.8,),
                   label='', linestyle='None')
green_patch = Line2D([0], [0], marker='o', markersize=20, color='green',
                     label='', linestyle='None')
blue_patch = Line2D([0], [0], marker='o', markersize=20, color=(0, 0, 1, 0.3),
                    label='', linestyle='None')

plt.scatter(z[filters, 0], z[filters, 1],
            edgecolor=plum, facecolor=plum + (0.9,), linewidth=1,
            s=10 * metadata['farright'][filters], marker="o")
plt.scatter(z[filters2, 0], z[filters2, 1],
            edgecolor='green', facecolor='green', linewidth=1,
            s=10 * metadata['climate'][filters2], marker="o")
plt.scatter(z[others, 0], z[others, 1],
            edgecolor='blue', facecolor=(0, 0, 1, 0.1), linewidth=0.1,
            s=1)

plt.xlabel("UMAP Dimension 1")
plt.ylabel("UMAP Dimension 2")
legend_labels = ['Bots sharing URLs from far-right domains',
                 'Bots sharing climate-related URLs',
                 'All other users']
plt.legend([red_patch, green_patch, blue_patch], legend_labels,
           fontsize=16, loc='lower left')
plt.savefig('figures/figure2c.jpeg', dpi=100, bbox_inches='tight')

#=======================================================================#
#
# Figure A2: Domain Sharing Preferences by Cluster
#
#=======================================================================#

def generate_subfigure(DOMAIN, OUTPUTNAME, LABEL, BOTONLY=False):
    """Save 'figures/figureA2<OUTPUTNAME>.jpeg' highlighting users with DOMAIN > 0.

    Reads the module-level `metadata` frame and projection `z`.

    Args:
        DOMAIN: column of `metadata` used both as the highlight condition
            (value > 0) and for marker area (2 x value).
        OUTPUTNAME: suffix appended to 'figures/figureA2' in the output path.
        LABEL: legend text for the highlighted users.
        BOTONLY: when True, only rows whose `bot` column equals 'bot' are
            highlighted.
    """
    highlight = metadata[DOMAIN] > 0
    if BOTONLY:
        highlight = (metadata.bot == 'bot') & highlight

    plum = (102/255, 0, 51/255)
    plt.figure(figsize=(14, 10))
    legend_handles = [
        Line2D([0], [0], marker='o', markersize=20, color=plum + (0.8,),
               label='', linestyle='None'),
        Line2D([0], [0], marker='o', markersize=20, color=(0, 0, 1, 0.3),
               label='', linestyle='None'),
    ]
    plt.scatter(z[highlight, 0], z[highlight, 1],
                edgecolor=plum, facecolor=plum + (0.9,), linewidth=1,
                s=2 * metadata[DOMAIN][highlight], marker="o")
    plt.scatter(z[~highlight, 0], z[~highlight, 1],
                edgecolor='blue', facecolor=(0, 0, 1, 0.1), linewidth=0.1,
                s=1)

    plt.xlabel("UMAP Dimension 1")
    plt.ylabel("UMAP Dimension 2")
    plt.legend(legend_handles, [LABEL, 'All other users'], fontsize=16, loc='lower left')
    plt.savefig('figures/figureA2' + OUTPUTNAME + '.jpeg', dpi=100, bbox_inches='tight')

# One entry per Figure A2 panel:
# (metadata column, output suffix, legend label, restrict to bots?)
subfigures = [
    ('pressprogress', 'a', 'Users sharing from pressprogress.ca', False),
    ('postmil', 'b', 'Users sharing from thepostmillennial.com', False),
    ('buffalo', 'c', 'Bots sharing from buffalochronicle.com', True),
]

for panel in subfigures:
    generate_subfigure(*panel)

#=======================================================================#
#
# Table 3: Top Words by UMAP Cluster
#
#=======================================================================#

labs = ['Liberal', 'Greens', 'NDP', 'PPC', 'Conservative', 'Foreign (Refugees)', 'Junk News']
fitted = pd.DataFrame({'screen_name':metadata.hash_, 'd0':z[:,0], 'd1':z[:,1],
                            'bot':metadata.bot, 'freq': metadata.freq})
fitted = fitted.sort_values(by='freq', ascending=False).reset_index(drop=True)

topwords = []
for cluster_name in labs:
    # Rectangle (h1, h2) x (v1, v2) in projected coordinates for this cluster.
    h1, h2, v1, v2 = coords[cluster_name]['coords']
    # Find users falling strictly inside that area of the UMAP projection:
    inside = (fitted.d0 > h1) & (fitted.d0 < h2) & (fitted.d1 > v1) & (fitted.d1 < v2)
    tempusers = fitted.loc[inside, 'screen_name'].tolist()
    if not tempusers:
        # An empty region would give a NaN anchor and meaningless neighbours;
        # record the cluster with no words instead.
        topwords.append((cluster_name, []))
        continue
    # Average embedding for those users (float64; width taken from the model).
    anchor = np.array([model.docvecs[u] for u in tempusers], dtype=np.float64).mean(axis=0)
    # 5 most similar words, searching only the first 5000 vocabulary entries
    # (gensim `restrict_vocab`).
    topl = [word for word, _score in
            model.wv.similar_by_vector(anchor, topn=5, restrict_vocab=5000)]
    topwords.append((cluster_name, topl))

# Vocabulary comes from tweets and may contain non-ASCII characters, so write
# UTF-8 explicitly rather than relying on the platform default encoding.
with open('tables/table3.txt', 'w', encoding='utf-8') as fout:
    print("Table 3: Top Words by UMAP Cluster\n", file=fout)
    for cluster_name, wlist in topwords:
        print("Topic %s" % cluster_name, file=fout)
        for w in wlist:
            print("\t" + w, file=fout)
        print('\n', file=fout)

#=======================================================================#
#
# Table 5: Predicting the Party Affiliation of Known Candidates
#
#=======================================================================#

# Ground Truth (Candidates)
candidates = pd.read_csv('data/candidate_usernames.csv')

clusters = ['Conservative', 'News', 'Greens', 'PPC', 'Foreign Accounts', 'Liberal', 'NDP']

# Add leaders to partisan groups: (cluster name, leader's handle in `hasher`).
leaders = [('Liberal', 'justintrudeau'),
           ('Conservative', 'andrewscheer'),
           ('NDP', 'thejagmeetsingh'),
           ('Greens', 'elizabethmay'),
           ('PPC', 'maximebernier')]
for group_name, handle in leaders:
    coords[group_name]['users'] = coords[group_name]['users'] + [hasher[handle]]
coords['News'] = {'users': [hasher['cbcnews']]}
# The foreign anchor pools three labelled groups.
coords['Foreign Accounts']['users'] = (coords['Foreign Accounts']['users']
                                       + coords['Foreign (Refugees)']['users']
                                       + coords['Junk News']['users'])

# Create anchors for partisan group attribution: the mean Doc2Vec vector of
# each cluster's users (float64; width taken from the model vectors).
anchors = []
for lab in clusters:
    members = coords[lab]['users']
    if not members:
        # An empty cluster would yield a NaN anchor, which cosine_similarity
        # rejects with an unhelpful message; fail with the cluster name instead.
        raise ValueError('No anchor users for cluster %r' % lab)
    vectors = np.array([model.docvecs[u] for u in members], dtype=np.float64)
    anchors.append((lab, vectors.mean(axis=0)))

# Assign group based on similarity. `tmp` is the same object as `metadata`,
# so one similarity column per cluster is added to it.
tmp = metadata
for lab, vector in anchors:
    # cosine_similarity returns an (n, 1) matrix; store it as a flat column.
    tmp[lab] = cosine_similarity(y, vector.reshape(1, -1)).ravel()

# Predicted group = cluster whose anchor is most cosine-similar to the user.
tmp['partisan_group'] = tmp[clusters].idxmax(axis=1)

clf = candidates.merge(tmp, on='hash_', how='inner')

with open('tables/table5.txt', 'w') as fout:
    print("Table 5: Predicting the Party Affiliation of Known Candidates\n", file=fout)
    print(pd.crosstab(clf.party_gold, clf.partisan_group).to_string(index=True, float_format=lambda x: '%0.1f'%x), file=fout)
    print("\nAccuracy Score, Party Affiliation of Candidates: %0.3f" %(metrics.accuracy_score(clf.party_gold, clf.partisan_group)*100), file=fout)
    print("\nWeighted F1 Score, Party Affiliation of Candidates: %0.3f" %(metrics.f1_score(clf.party_gold, clf.partisan_group, average='weighted')), file=fout)

#=======================================================================#
#
# Table 6: Estimated Size of Partisan Groups vs Popular Vote
#
#=======================================================================#

# Print distribution excluding assumed non-voters.
# Only 'News' and 'Foreign Accounts' can actually occur as `partisan_group`
# values (it is an idxmax over `clusters`); the other entries are harmless.
filt = ['Foreign (Refugees)', 'News', 'Foreign Accounts', 'Junk News']

# 2019 federal election popular vote (%), keyed by cluster label. Keyed by
# name because value_counts() orders rows by frequency, which depends on the
# data, so a positional list could attach votes to the wrong party.
# NOTE(review): party-to-value mapping follows the official 2019 results;
# confirm against the paper.
popular_vote = {'Conservative': 34.4, 'Liberal': 33.1, 'NDP': 15.9,
                'PPC': 1.6, 'Greens': 6.5}

voters = tmp.loc[~tmp.partisan_group.isin(filt), 'partisan_group']
party_table = ((voters.value_counts(normalize=True) * 100)
               .rename_axis('Party Group')
               .reset_index(name='Estimated Proportion'))
party_table['Popular Vote'] = party_table['Party Group'].map(popular_vote)

with open('tables/table6.txt', 'w') as fout:
    print("Table 6: Estimated Size of Partisan Groups vs Popular Vote\n", file=fout)
    print(party_table.to_string(index=False, float_format=lambda x: '%0.1f'%x), file=fout)

#=======================================================================#
#
# Table 4: Odds of Observing Social Bots and Canadian Geotags, by Cluster
#
#=======================================================================#

# Group proportions and bot density
# Share of all users assigned to each group.
groupprops = tmp['partisan_group'].value_counts()/tmp.shape[0]
# Share of all bots assigned to each group (groups with no bots are absent).
bot_mask = (tmp.bot == 'bot')
botprops = tmp.loc[bot_mask, 'partisan_group'].value_counts()/bot_mask.sum()

# Compute odds ratios
ntmp = tmp
ntmp['botbinary'] = [1 if x=='bot' else 0 for x in ntmp.bot]
# Users with a known geotag origin only.
ctmp = tmp[pd.notnull(tmp.origin)].reset_index(drop=True)
ctmp['canadian'] = [1 if x=='canadian' else 0 for x in ctmp.origin]

results = []
for lab in clusters:
    # 0/1 membership indicators kept as standalone Series. `ntmp` is the same
    # object as `tmp`/`metadata`, whose column `lab` holds the cosine
    # similarities behind `partisan_group`; writing the indicators into that
    # column would overwrite those scores.
    in_group = (ntmp.partisan_group == lab).astype('int64').rename(lab)
    c_in_group = (ctmp.partisan_group == lab).astype('int64').rename(lab)
    # Odds of a bot v human, by group (crosstab rows/cols sorted 0 then 1)
    ntab = sm.stats.Table2x2(pd.crosstab(in_group, ntmp['botbinary']))
    ODR = ntab.oddsratio
    ODR_P = ntab.oddsratio_pvalue()
    # Odds of Canadian geotag v foreign geotag, given geotags
    ctab = sm.stats.Table2x2(pd.crosstab(c_in_group, ctmp['canadian']))
    CODR = ctab.oddsratio
    CODR_P = ctab.oddsratio_pvalue()
    # Group proportions; a group that received no bots has density 0.
    bP = botprops.get(lab, 0.0)
    gP = groupprops.get(lab, 0.0)
    results.append((lab, '%0.2f' %(gP*100), '%0.2f' %(bP*100), '%0.3f' %ODR, '%0.3f' %ODR_P,
                   '%0.3f' %CODR, '%0.3f' %CODR_P))

RES = pd.DataFrame(results, columns=['Group', 'Percent (%)', 'Bot Density (%)',
                                     'Bot: OR', 'Bot: p', 'Can: OR', 'Can: p'])

# Sort results for presentation
sorter = {'Liberal':1, 'Conservative':0, 'NDP':2, 'PPC':3, 'Greens':4, 'News':5,
          'Foreign Accounts':6, 'Junk News':7, 'Foreign (Refugees)':8}
RES['sorter'] = RES['Group'].map(sorter)
RES = RES.sort_values(by='sorter', ascending=True)
del RES['sorter']

with open('tables/table4.txt', 'w') as fout:
    print("Table 4: Odds of Observing Social Bots and Canadian Geotags, by Cluster\n", file=fout)
    print(RES.to_string(index=False), file=fout)
